10. 训练网络
07 训练网络 V1
交叉熵损失
PyTorch 文档提到,交叉熵损失函数包括两步:
- 首先对看到的任何输出应用 softmax 函数
- 然后应用 NLLLoss 负对数似然损失
接着返回一批数据的平均损失。因为交叉熵损失会应用 softmax 函数,所以我们不需要在模型定义的 forward
函数中应用 softmax 函数;我们还可以采用另一种方式。
另一种方式
我们可以将 softmax 步骤和 NLLLoss 步骤分开处理。
- 在模型的
forward
函数中,我们可以向输出x
应用 softmax 激活函数。
...
...
# a softmax layer to convert 10 outputs into a distribution of class probabilities
x = F.log_softmax(x, dim=1)
return x
- 然后,在定义损失函数时应用 NLLLoss
# cross entropy loss combines softmax and nn.NLLLoss() in one single class
# here, we've separated them
criterion = nn.NLLLoss()
这样会将常规的 criterion = nn.CrossEntropy()
分成两步:softmax 和NLLLoss;如果你希望模型输出是类别概率,而不是类别分数的话,可以采取这种方式。